{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "9Dr8xIsuTn9i" }, "source": [ "# GAT implementation\n", "Graphic Attention Network\n", "\n", "Official resources from [Blog](https://dsgiitr.com/blogs/gat/)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "Sskfh5PeTn9q" }, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F" ] }, { "cell_type": "markdown", "metadata": { "id": "kKKavjKGTn9r" }, "source": [ "## Structure" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "5Jj-SdfVTn9s" }, "outputs": [], "source": [ "class GATLayer(nn.Module):\n", " \"\"\"\n", " Simple PyTorch Implementation of the Graph Attention layer.\n", " \"\"\"\n", " def __init__(self):\n", " super(GATLayer, self).__init__()\n", " \n", " def forward(self, input, adj):\n", " print(\"\")" ] }, { "cell_type": "markdown", "metadata": { "id": "rN_8zBYaTn9t" }, "source": [ "## Let's start from the forward method" ] }, { "cell_type": "markdown", "metadata": { "id": "hRTX0Rh0Tn9t" }, "source": [ "### Linear Transformation\n", "\n", "$$\n", "\\bar{h'}_i = \\textbf{W}\\cdot \\bar{h}_i\n", "$$\n", "with $\\textbf{W}\\in\\mathbb R^{F'\\times F}$ and $\\bar{h}_i\\in\\mathbb R^{F}$.\n", "\n", "$$\n", "\\bar{h'}_i \\in \\mathbb{R}^{F'}\n", "$$" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "n0p69T_NTn9u", "outputId": "af193778-f308-4850-83b4-73d59fbc55c5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([3, 2])\n" ] } ], "source": [ "in_features = 5\n", "out_features = 2\n", "nb_nodes = 3\n", "\n", "W = nn.Parameter(torch.zeros(size=(in_features, out_features))) #xavier paramiter inizializator\n", "nn.init.xavier_uniform_(W.data, gain=1.414)\n", "\n", "input = torch.rand(nb_nodes,in_features) \n", "\n", "\n", "# linear transformation\n", "h = torch.mm(input, W)\n", "N = h.size()[0]\n", "\n", "print(h.shape)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0XVLFYXtVK5h", "outputId": "f3f67fed-c835-4b54-a506-4d6305ff7198" }, "outputs": [ { "data": { "text/plain": [ "torch.Size([3, 5])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "input.size()" ] }, { "cell_type": "markdown", "metadata": { "id": "jrWAPxg9Tn9u" }, "source": [ "### Attention Mechanism" ] }, { "cell_type": "markdown", "metadata": { "id": "ocxSDyLvTn9v" }, "source": [ "![title](https://github.com/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial3/AttentionMechanism.png?raw=1)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0oPiXvoTTn9v", "outputId": "2e78f6dd-7f35-40f2-b037-fb59d1456625" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([4, 1])\n" ] } ], "source": [ "a = nn.Parameter(torch.zeros(size=(2*out_features, 1))) #xavier paramiter inizializator\n", "nn.init.xavier_uniform_(a.data, gain=1.414)\n", "print(a.shape)\n", "\n", "leakyrelu = nn.LeakyReLU(0.2) # LeakyReLU" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "hI3VJ_bBTn9v" }, "outputs": [], "source": [ "a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * out_features)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" 
}, "id": "Pwzr-VxcYOY3", "outputId": "c142f630-424f-42b6-dc51-5be5609d03e9" }, "outputs": [ { "data": { "text/plain": [ "torch.Size([3, 3, 4])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a_input.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "ppoK_cbmTn9w" }, "source": [ "![title](https://github.com/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial3/a_input.png?raw=1)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "9XfnXgfuTn9w" }, "outputs": [], "source": [ "e = leakyrelu(torch.matmul(a_input, a).squeeze(2))" ] }, { "cell_type": "markdown", "metadata": { "id": "SAvRBFcNdEK0" }, "source": [ "Row $i$ of e is the coeffcients in for row $i$. Thus, we will get a $N*N$ matrix, where $N$ is the number of nodes." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "KM337YyBTn9w", "outputId": "15d25a07-d3f8-4503-80d0-e7e815fa3167" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([3, 3, 4]) torch.Size([4, 1])\n", "\n", "torch.Size([3, 3, 1])\n", "\n", "torch.Size([3, 3])\n" ] } ], "source": [ "print(a_input.shape,a.shape)\n", "print(\"\")\n", "print(torch.matmul(a_input,a).shape)\n", "print(\"\")\n", "print(torch.matmul(a_input,a).squeeze(2).shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "M3U302C3Tn9w" }, "source": [ "### Masked Attention" ] }, { "cell_type": "markdown", "metadata": { "id": "XQU8WB_rds8i" }, "source": [ "Since $e_{ij}$ are computed for all pairs of in this $3\\times3$ matrix, we need to mask out those coefficients for those not in the neighborhood of each node, i.e., only keep coefficients that correspond to edges in graph." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wnCBnwGWTn9x", "outputId": "43e65b97-a7f2-4d69-eed9-b3df2684167a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([3, 3])\n" ] } ], "source": [ "# Masked Attention\n", "adj = torch.randint(2, (3, 3))\n", "\n", "zero_vec = -9e15*torch.ones_like(e)\n", "print(zero_vec.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "DRUtOxpon8WM" }, "source": [ "We use $-9e15$ as the zero entries, because we will perform exponential operation on $e_{i,j}$ later and a small enough negative number will produce zero on exponent." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "mbopx7bkTn9x", "outputId": "fc0b6c6a-b097-4c9f-e17d-21f58b68dc2e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[1, 1, 1],\n", " [1, 0, 1],\n", " [1, 0, 0]]) \n", " tensor([[-0.3351, -0.2840, -0.3298],\n", " [-0.2346, -0.1835, -0.2293],\n", " [-0.2771, -0.2260, -0.2718]], grad_fn=) \n", " tensor([[-9.0000e+15, -9.0000e+15, -9.0000e+15],\n", " [-9.0000e+15, -9.0000e+15, -9.0000e+15],\n", " [-9.0000e+15, -9.0000e+15, -9.0000e+15]])\n" ] }, { "data": { "text/plain": [ "tensor([[-3.3511e-01, -2.8395e-01, -3.2975e-01],\n", " [-2.3463e-01, -9.0000e+15, -2.2928e-01],\n", " [-2.7714e-01, -9.0000e+15, -9.0000e+15]], grad_fn=)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "attention = torch.where(adj > 0, e, zero_vec)\n", "print(adj,\"\\n\",e,\"\\n\",zero_vec)\n", "attention" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "Uo85YqppTn9y" }, "outputs": [], "source": [ "attention = F.softmax(attention, dim=1) # softmax over columns(each row vector) \n", "h_prime = torch.matmul(attention, h)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "gKWO3NItTn9y", "outputId": "87fe0783-75d2-4af7-a360-5cd41e96993a" }, "outputs": [ { "data": { "text/plain": [ "tensor([[0.3270, 0.3442, 0.3288],\n", " [0.4987, 0.0000, 0.5013],\n", " [1.0000, 0.0000, 0.0000]], grad_fn=)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "attention" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "P8BpTbdiTn9z", "outputId": "b7d8e381-14d9-4db3-da07-c0270f18d2fd" }, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.5785, 0.6305],\n", " [-0.7469, 0.5793],\n", " [-0.9134, 0.4681]], grad_fn=)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "h_prime" ] }, { "cell_type": "markdown", "metadata": { "id": "8-5txR6-Tn9z" }, "source": [ "#### h_prime vs h" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "VLDJeeEaTn9z", "outputId": "3baeea8b-c6c9-4b0e-9e23-9c90e4bb08eb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-0.5785, 0.6305],\n", " [-0.7469, 0.5793],\n", " [-0.9134, 0.4681]], grad_fn=) \n", " tensor([[-0.9134, 0.4681],\n", " [-0.2576, 0.7279],\n", " [-0.5813, 0.6900]], grad_fn=)\n" ] } ], "source": [ "print(h_prime,\"\\n\",h)" ] }, { "cell_type": "markdown", "metadata": { "id": "w0WakOlnTn90" }, "source": [ "# Build the layer" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "UEWAEFgyTn90" }, "outputs": [], "source": [ "class GATLayer(nn.Module):\n", " def __init__(self, in_features, out_features, dropout, alpha, concat=True):\n", " super(GATLayer, self).__init__()\n", " \n", " '''\n", " TODO\n", " '''\n", " \n", " def forward(self, input, adj):\n", " # Linear Transformation\n", " h = torch.mm(input, self.W) # matrix multiplication\n", " N = h.size()[0]\n", "\n", " # Attention Mechanism\n", " a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)\n", " e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))\n", "\n", " # Masked Attention\n", " zero_vec = -9e15*torch.ones_like(e)\n", " 
"        attention = torch.where(adj > 0, e, zero_vec)\n", "\n", "        attention = F.softmax(attention, dim=1)\n", "        attention = F.dropout(attention, self.dropout, training=self.training)\n", "        h_prime = torch.matmul(attention, h)\n", "\n", "        if self.concat:\n", "            return F.elu(h_prime)\n", "        else:\n", "            return h_prime" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": { "id": "ABniBOkqTn91" }, "outputs": [], "source": [ "class GATLayer(nn.Module):\n", "    def __init__(self, in_features, out_features, dropout, alpha, concat=True):\n", "        super(GATLayer, self).__init__()\n", "        self.dropout = dropout            # drop prob = 0.6\n", "        self.in_features = in_features    # number of input features per node (F)\n", "        self.out_features = out_features  # number of output features per node (F')\n", "        self.alpha = alpha                # LeakyReLU negative input slope, alpha = 0.2\n", "        self.concat = concat              # concat = True for all layers except the output layer\n", "\n", "        # Xavier Initialization of Weights\n", "        # Alternatively use weights_init to apply weights of choice\n", "        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))\n", "        nn.init.xavier_uniform_(self.W.data, gain=1.414)\n", "\n", "        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))\n", "        nn.init.xavier_uniform_(self.a.data, gain=1.414)\n", "\n", "        # LeakyReLU\n", "        self.leakyrelu = nn.LeakyReLU(self.alpha)\n", "\n", "    def forward(self, input, adj):\n", "        # Linear Transformation\n", "        h = torch.mm(input, self.W) # matrix multiplication\n", "        N = h.size()[0]\n", "\n", "        # Attention Mechanism\n", "        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)\n", "        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))\n", "\n", "        # Masked Attention\n", "        zero_vec = -9e15*torch.ones_like(e)\n", "        attention = torch.where(adj > 0, e, zero_vec)\n", "\n", "        attention = F.softmax(attention, dim=1)\n", "        attention = F.dropout(attention, self.dropout, training=self.training)\n", "        h_prime = torch.matmul(attention, h)\n", "\n", "        if self.concat:\n", "            return F.elu(h_prime)\n", "        else:\n", "            return h_prime" ] },
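{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick check (a minimal sketch), we can run the hand-built layer on the toy `input` and `adj` defined above; `dropout=0.6` and `alpha=0.2` simply match the values used elsewhere in this notebook." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# run the hand-built layer on the toy graph (3 nodes, 5 features, random 3x3 adjacency)\n", "# the values are random; we only check the output shape (nb_nodes, out_features)\n", "layer = GATLayer(in_features=5, out_features=2, dropout=0.6, alpha=0.2)\n", "print(layer(input, adj).shape)  # expected: torch.Size([3, 2])" ] },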
self.out_head = 1\n", " \n", " \n", " self.conv1 = GATConv(dataset.num_features, self.hid, heads=self.in_head, dropout=0.6)\n", " self.conv2 = GATConv(self.hid*self.in_head, dataset.num_classes, concat=False,\n", " heads=self.out_head, dropout=0.6)\n", "\n", " def forward(self, data):\n", " x, edge_index = data.x, data.edge_index\n", " \n", " x = F.dropout(x, p=0.6, training=self.training)\n", " x = self.conv1(x, edge_index)\n", " x = F.elu(x)\n", " x = F.dropout(x, p=0.6, training=self.training)\n", " x = self.conv2(x, edge_index)\n", " \n", " return F.log_softmax(x, dim=1)\n", " \n", " \n", " \n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "device = \"cpu\"\n", "\n", "model = GAT().to(device)\n", "data = dataset[0].to(device)\n", "\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)\n", "\n", "model.train()\n", "for epoch in range(1000):\n", " model.train()\n", " optimizer.zero_grad()\n", " out = model(data)\n", " loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n", " \n", " if epoch%200 == 0:\n", " print(loss)\n", " \n", " loss.backward()\n", " optimizer.step()\n", " \n", " " ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "vPJMnlXcTn93" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.8110\n" ] } ], "source": [ "model.eval()\n", "_, pred = model(data).max(dim=1)\n", "correct = float(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())\n", "acc = correct / data.test_mask.sum().item()\n", "print('Accuracy: {:.4f}'.format(acc))" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "“Tutorial3.ipynb”的副本", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 1 }